import torch.utils.data as data
import os
import csv
import json
import numpy as np
import pandas as pd
import torch
import pdb
import time
import random

from scipy import interpolate

import utils
import config


class ThumosFeature(data.Dataset):
    """Dataset of pre-extracted segment features with point-level labels.

    Files read below ``data_path``:

    * ``features/<mode>/<modal>/<video>.npy`` -- one feature array per video
      (both ``rgb`` and ``flow`` when ``modal == 'all'``),
    * ``split_<mode>.txt`` -- one video name per line,
    * ``fps_dict.json`` -- mapping video name -> frame rate,
    * ``gt.json`` -- annotations under ``['database'][video]['annotations']``,
    * ``point_gaussian/point_labels.csv`` -- point labels (only when
      ``supervision == 'point'``), columns ``video_id``, ``point``,
      ``class_index``.

    Each item is resampled along time so that segments carrying a point label
    (and their direct neighbours) are represented by more rows.
    """

    # Resampling weight of a segment that carries a point label / of the
    # segment right before or after a labelled group. Every other segment has
    # weight 1.
    POINT_WEIGHT = 2.5
    NEIGHBOR_WEIGHT = 1.67

    def __init__(self, data_path, mode, modal, feature_fps, num_segments, sampling, seed=-1, supervision='point'):
        """Load the split list, fps table, annotations and point labels.

        Args:
            data_path: Root directory of the dataset files listed on the class.
            mode: Split name; used in the feature directory and split file name.
            modal: ``'all'`` to concatenate rgb and flow, otherwise the name of
                a single feature sub-directory.
            feature_fps: Numerator of the point -> segment index conversion.
            num_segments: Number of segments to sample, ``-1`` to keep all.
            sampling: ``'random'`` or ``'uniform'``.
            seed: When ``>= 0`` it is passed to ``utils.set_seed``.
            supervision: ``'point'`` loads the point label CSV.
        """
        if seed >= 0:
            utils.set_seed(seed)

        self.mode = mode
        self.modal = modal
        self.feature_fps = feature_fps
        self.num_segments = num_segments
        self.sampling = sampling
        self.supervision = supervision

        if self.modal == 'all':
            self.feature_path = [os.path.join(data_path, 'features', self.mode, _modal)
                                 for _modal in ['rgb', 'flow']]
        else:
            self.feature_path = os.path.join(data_path, 'features', self.mode, self.modal)

        # Context managers guarantee the handles are closed even on errors.
        split_path = os.path.join(data_path, 'split_{}.txt'.format(self.mode))
        with open(split_path, 'r') as split_file:
            self.vid_list = [line.strip() for line in split_file]

        with open(os.path.join(data_path, 'fps_dict.json')) as fps_file:
            self.fps_dict = json.load(fps_file)

        anno_path = os.path.join(data_path, 'gt.json')
        with open(anno_path, 'r') as anno_file:
            self.anno = json.load(anno_file)

        # config.class_dict is inverted here, so its values become the keys.
        self.class_name_to_idx = dict((v, k) for k, v in config.class_dict.items())
        self.num_classes = len(self.class_name_to_idx.keys())

        if self.supervision == 'point':
            self.point_anno = pd.read_csv(os.path.join(data_path, 'point_gaussian', 'point_labels.csv'))

        # Per-video slots initialised to -1; nothing in this class writes them.
        self.stored_info_all = {'new_dense_anno': [-1] * len(self.vid_list), 'sequence_score': [-1] * len(self.vid_list)}

    def __len__(self):
        """Return the number of videos in the split."""
        return len(self.vid_list)

    def __getitem__(self, index):
        """Return one resampled video.

        Returns:
            Tuple ``(index, sampled_feature, label, sampled_point_anno,
            stored_info, video_name, num_rows)`` where ``num_rows`` is
            ``sampled_feature.shape[0]``.
        """
        data, vid_num_seg, sample_idx = self.get_data(index)
        label, point_anno, dynamic_segment_weights = self.get_label(index, vid_num_seg, sample_idx)
        sampled_point_anno, sampled_feature = self.dynamic_sample(data, point_anno, dynamic_segment_weights)

        stored_info = {'new_dense_anno': self.stored_info_all['new_dense_anno'][index], 'sequence_score': self.stored_info_all['sequence_score'][index]}

        return index, sampled_feature, label, sampled_point_anno, stored_info, self.vid_list[index], sampled_feature.shape[0]

    def dynamic_sample(self, feature, point_anno, dynamic_segment_weights):
        """Resample features and point labels according to per-segment weights.

        The output has ``round(sum(dynamic_segment_weights))`` rows; a segment
        with a larger weight is covered by proportionally more rows. Features
        are linearly interpolated between neighbouring segments.

        Args:
            feature: Per-segment features, first axis is time.
            point_anno: Tensor ``[segments, classes]`` with 1 at labelled points.
            dynamic_segment_weights: One weight per row of ``feature``.

        Returns:
            ``(sampled_point_anno, sampled_feature)`` as numpy arrays.
        """
        vid_num_seg = feature.shape[0]
        # Cumulative weight axis: segment i spans [cumsum[i], cumsum[i + 1]].
        weights_cumsum = np.concatenate((np.zeros((1,), dtype=float), np.cumsum(dynamic_segment_weights)), axis=0)
        num_samples = np.round(weights_cumsum[-1]).astype(int)
        f_upsample = interpolate.interp1d(weights_cumsum, np.arange(vid_num_seg + 1), kind='linear',
                                          axis=0, fill_value='extrapolate')
        scale_x = np.linspace(1, num_samples, num_samples)
        # Fractional, 1-based segment position of every output row.
        sampled_time = f_upsample(scale_x)
        f_feature = interpolate.interp1d(np.arange(1, vid_num_seg + 1), feature, kind='linear', axis=0, fill_value='extrapolate')
        sampled_feature = f_feature(sampled_time)

        sampled_point_anno = np.zeros([sampled_feature.shape[0], self.num_classes], dtype=np.float32)
        point_anno_agnostic = point_anno.max(dim=1)[0]
        act_idx_lst = np.where(point_anno_agnostic == 1)[0]
        for act_idx in act_idx_lst:
            temp_idx = act_idx + 1  # 1-based, like sampled_time
            mark_idx = np.where((sampled_time >= (temp_idx - 0.51)) & (sampled_time < temp_idx + 0.51))[0]
            if mark_idx.size == 0:
                # No row fell into the window: keep the label on the closest
                # row instead of failing on an empty index array.
                mark_idx = np.array([np.argmin(np.abs(sampled_time - temp_idx))])
            # Label every row from the first to the last match, inclusive.
            sampled_point_anno[mark_idx[0]: mark_idx[-1] + 1] = np.asarray(point_anno[act_idx])

        return sampled_point_anno, sampled_feature

    def _sample_indices(self, vid_num_seg):
        """Dispatch to the sampler selected by ``self.sampling``."""
        if self.sampling == 'random':
            return self.random_perturb(vid_num_seg)
        elif self.sampling == 'uniform':
            return self.uniform_sampling(vid_num_seg)
        raise AssertionError('Not supported sampling !')

    def get_data(self, index):
        """Load and sample the features of video ``index``.

        Returns:
            ``(feature_tensor, vid_num_seg, sample_idx)`` where ``vid_num_seg``
            is the segment count before sampling and ``sample_idx`` the
            indices that were kept.
        """
        vid_name = self.vid_list[index]

        if self.modal == 'all':
            rgb_feature = np.load(os.path.join(self.feature_path[0],
                                    vid_name + '.npy')).astype(np.float32)
            flow_feature = np.load(os.path.join(self.feature_path[1],
                                    vid_name + '.npy')).astype(np.float32)

            vid_num_seg = rgb_feature.shape[0]
            sample_idx = self._sample_indices(vid_num_seg)

            # Both streams use the same indices so they stay aligned in time.
            feature = np.concatenate((rgb_feature[sample_idx], flow_feature[sample_idx]), axis=1)
        else:
            feature = np.load(os.path.join(self.feature_path,
                                    vid_name + '.npy')).astype(np.float32)

            vid_num_seg = feature.shape[0]
            sample_idx = self._sample_indices(vid_num_seg)

            feature = feature[sample_idx]

        return torch.from_numpy(feature), vid_num_seg, sample_idx

    @classmethod
    def _segment_weights(cls, point_groups, vid_num_seg):
        """Build the per-segment resampling weights.

        Args:
            point_groups: Iterable of index sequences; the first and last
                element of each non-empty sequence delimit a labelled range.
            vid_num_seg: Number of segments.

        Returns:
            Float array of length ``vid_num_seg``: ``POINT_WEIGHT`` inside each
            range, ``NEIGHBOR_WEIGHT`` on the segment right before/after it,
            1 elsewhere. Later groups overwrite earlier ones.
        """
        weights = np.ones((vid_num_seg,), dtype=float)
        for group in point_groups:
            if len(group) > 0:
                weights[group[0]: group[-1] + 1] = cls.POINT_WEIGHT
                if group[0] - 1 >= 0:
                    weights[group[0] - 1] = cls.NEIGHBOR_WEIGHT
                if group[-1] + 1 <= vid_num_seg - 1:
                    weights[group[-1] + 1] = cls.NEIGHBOR_WEIGHT
        return weights

    def get_label(self, index, vid_num_seg, sample_idx):
        """Build the video-level label, point labels and resampling weights.

        Args:
            index: Position of the video in ``self.vid_list``.
            vid_num_seg: Segment count before sampling.
            sample_idx: Indices returned by ``get_data``.

        Returns:
            ``(label, point_label_tensor, dynamic_segment_weights)`` for
            ``supervision == 'point'``; point labels and weights are both
            indexed by ``sample_idx`` so they line up with the sampled
            features. ``(label, empty_tensor)`` for ``supervision == 'video'``.
        """
        vid_name = self.vid_list[index]
        anno_list = self.anno['database'][vid_name]['annotations']
        label = np.zeros([self.num_classes], dtype=np.float32)

        for _anno in anno_list:
            label[self.class_name_to_idx[_anno['label']]] = 1

        if self.supervision == 'video':
            # NOTE(review): __getitem__ unpacks three values, so this
            # two-element return only suits direct callers of get_label --
            # confirm whether 'video' mode is meant to be usable end to end.
            return label, torch.Tensor(0)

        elif self.supervision == 'point':
            temp_anno = np.zeros([vid_num_seg, self.num_classes], dtype=np.float32)
            t_factor = self.feature_fps / (self.fps_dict[vid_name] * 16)

            temp_df = self.point_anno[self.point_anno["video_id"] == vid_name][['point', 'class_index']]

            point_anno_lst = []
            for point, class_idx in zip(temp_df['point'], temp_df['class_index']):
                seg_idx = int(point * t_factor)
                # Clamp points that map past the last segment (same handling
                # as ActivityFeature.get_label) instead of raising IndexError.
                if seg_idx >= vid_num_seg:
                    seg_idx = vid_num_seg - 1
                temp_anno[seg_idx][class_idx] = 1
                point_anno_lst.append(seg_idx)

            point_label = temp_anno[sample_idx, :]

            # utils.grouping presumably merges neighbouring indices into
            # groups -- confirm against utils.
            dynamic_segment_weights = self._segment_weights(utils.grouping(point_anno_lst), vid_num_seg)
            # dynamic_sample pairs one weight with each *sampled* feature row;
            # without this indexing the lengths differ whenever sampling
            # changes the segment count and interp1d rejects the input.
            dynamic_segment_weights = dynamic_segment_weights[sample_idx]

            return label, torch.from_numpy(point_label), dynamic_segment_weights

    def random_perturb(self, length):
        """Pick ``num_segments`` indices, one at random from each time bin.

        Returns ``arange(length)`` when ``num_segments`` is ``-1`` or equals
        ``length``. A bin includes the start of the next bin, so neighbouring
        picks can coincide.
        """
        if self.num_segments == length or self.num_segments == -1:
            return np.arange(length).astype(int)
        samples = np.arange(self.num_segments) * length / self.num_segments
        for i in range(self.num_segments):
            if i < self.num_segments - 1:
                if int(samples[i]) != int(samples[i + 1]):
                    samples[i] = np.random.choice(range(int(samples[i]), int(samples[i + 1]) + 1))
                else:
                    samples[i] = int(samples[i])
            else:
                # Last bin runs to the end of the video.
                if int(samples[i]) < length - 1:
                    samples[i] = np.random.choice(range(int(samples[i]), length))
                else:
                    samples[i] = int(samples[i])
        return samples.astype(int)

    def uniform_sampling(self, length):
        """Pick ``num_segments`` evenly spaced indices.

        Returns ``arange(length)`` when the video has at most ``num_segments``
        segments or ``num_segments`` is ``-1``.
        """
        if length <= self.num_segments or self.num_segments == -1:
            return np.arange(length).astype(int)
        samples = np.arange(self.num_segments) * length / self.num_segments
        samples = np.floor(samples)
        return samples.astype(int)


class ActivityFeature(data.Dataset):
    """Dataset of pre-extracted segment features with point-level labels.

    Files read below ``data_path``:

    * ``features/<mode>/<video>.npy`` -- one feature array per video,
    * ``split_<mode>.txt`` -- one video name per line,
    * ``gt.json`` -- per video ``'annotations'`` and ``'fps'`` under
      ``['database'][video]``,
    * ``point_gaussian/point_label_ac.csv`` -- point labels (only when
      ``supervision == 'point'``), columns ``video_id``, ``point``,
      ``class_index``.

    Unlike the other datasets in this file, the full video is first resampled
    by ``dynamic_sample`` and only then reduced to ``num_segments`` rows,
    keeping every labelled row.
    """

    # Resampling weight of a segment that carries a point label / of the
    # segment right before or after a labelled group. Every other segment has
    # weight 1.
    POINT_WEIGHT = 2.5
    NEIGHBOR_WEIGHT = 1.67

    def __init__(self, data_path, mode, modal, feature_fps, num_segments, sampling, seed=-1, supervision='point'):
        """Load the split list, annotations and point labels.

        Args:
            data_path: Root directory of the dataset files listed on the class.
            mode: Split name; used in the feature directory and split file name.
            modal: Stored on the instance; not used for path construction here.
            feature_fps: Numerator of the point -> segment index conversion.
            num_segments: Number of rows to sample, ``-1`` to keep all.
            sampling: ``'random'`` or ``'uniform'``.
            seed: When ``>= 0`` it is passed to ``utils.set_seed``.
            supervision: ``'point'`` loads the point label CSV.
        """
        if seed >= 0:
            utils.set_seed(seed)

        self.mode = mode
        self.modal = modal
        self.feature_fps = feature_fps
        self.num_segments = num_segments
        self.sampling = sampling
        self.supervision = supervision

        self.feature_path = os.path.join(data_path, 'features', self.mode)

        # Context managers guarantee the handles are closed even on errors.
        split_path = os.path.join(data_path, 'split_{}.txt'.format(self.mode))
        with open(split_path, 'r') as split_file:
            self.vid_list = [line.strip() for line in split_file]

        anno_path = os.path.join(data_path, 'gt.json')
        with open(anno_path, 'r') as anno_file:
            self.anno = json.load(anno_file)

        # config.class_dict_A is inverted here, so its values become the keys.
        self.class_name_to_idx = dict((v, k) for k, v in config.class_dict_A.items())
        self.num_classes = len(self.class_name_to_idx.keys())

        if self.supervision == 'point':
            self.point_anno = pd.read_csv(os.path.join(data_path, 'point_gaussian', 'point_label_ac.csv'))

        # Per-video slots initialised to -1; nothing in this class writes them.
        self.stored_info_all = {'new_dense_anno': [-1] * len(self.vid_list),
                                'sequence_score': [-1] * len(self.vid_list)}

    def __len__(self):
        """Return the number of videos in the split."""
        return len(self.vid_list)

    def __getitem__(self, index):
        """Return one resampled video.

        Returns:
            Tuple ``(index, sampled_feature, label, sampled_point_anno,
            stored_info, video_name, num_rows)`` where ``num_rows`` is
            ``sampled_feature.shape[0]``.
        """
        data, vid_num_seg = self.get_data(index)
        label, point_anno, dynamic_segment_weights = self.get_label(index, vid_num_seg)
        sampled_point_anno, sampled_feature = self.dynamic_sample(index, data, point_anno, dynamic_segment_weights)

        stored_info = {'new_dense_anno': self.stored_info_all['new_dense_anno'][index],
                       'sequence_score': self.stored_info_all['sequence_score'][index]}

        return index, sampled_feature, label, sampled_point_anno, stored_info, self.vid_list[index], \
        sampled_feature.shape[0]

    def _sample_indices(self, length):
        """Dispatch to the sampler selected by ``self.sampling``."""
        if self.sampling == 'random':
            return self.random_perturb(length)
        elif self.sampling == 'uniform':
            return self.uniform_sampling(length)
        raise AssertionError('Not supported sampling !')

    def dynamic_sample(self, index, feature, point_anno, dynamic_segment_weights):
        """Resample by segment weight, then reduce to ``num_segments`` rows.

        Step 1 stretches the time axis so that a segment with weight ``w`` is
        covered by about ``w`` rows (features linearly interpolated). Step 2
        draws row indices with the configured sampler and swaps randomly
        chosen unlabelled picks for labelled rows that were missed, so every
        labelled row survives.

        Args:
            index: Unused; kept for interface compatibility.
            feature: Per-segment features, first axis is time.
            point_anno: Tensor ``[segments, classes]`` with 1 at labelled points.
            dynamic_segment_weights: One weight per row of ``feature``.

        Returns:
            ``(sampled_point_anno, sampled_feature)`` as numpy arrays.

        Raises:
            ValueError: From ``random.randint`` when there are more labelled
                rows than sampled rows to swap out.
        """
        vid_num_seg = feature.shape[0]
        if vid_num_seg > 1:
            # Cumulative weight axis: segment i spans [cumsum[i], cumsum[i + 1]].
            weights_cumsum = np.concatenate(
                (np.zeros((1,), dtype=float), np.cumsum(dynamic_segment_weights)), axis=0)
            num_samples = np.round(weights_cumsum[-1]).astype(int)
            f_upsample = interpolate.interp1d(weights_cumsum, np.arange(vid_num_seg + 1), kind='linear',
                                              axis=0, fill_value='extrapolate')
            scale_x = np.linspace(1, num_samples, num_samples)
            # Fractional, 1-based segment position of every output row.
            sampled_time = f_upsample(scale_x)
            f_feature = interpolate.interp1d(np.arange(1, vid_num_seg + 1), feature, kind='linear', axis=0,
                                             fill_value='extrapolate')
            sampled_feature = f_feature(sampled_time)
        else:
            # interp1d needs at least two points, so a one-segment video is
            # simply duplicated. Times are 1-based like in the branch above,
            # which lets the labelling window below match both copies.
            sampled_feature = np.repeat(np.asarray(feature), 2, axis=0)
            sampled_time = np.ones((sampled_feature.shape[0],), dtype=float)

        sampled_point_anno = np.zeros([sampled_feature.shape[0], self.num_classes], dtype=np.float32)
        point_anno_agnostic = point_anno.max(dim=1)[0]
        act_idx_lst = np.where(point_anno_agnostic == 1)[0]
        for act_idx in act_idx_lst:
            temp_idx = act_idx + 1  # 1-based, like sampled_time
            mark_idx = np.where((sampled_time >= (temp_idx - 0.51)) & (sampled_time < temp_idx + 0.51))[0]
            if mark_idx.size == 0:
                # No row fell into the window: keep the label on the closest
                # row instead of failing on an empty index array.
                mark_idx = np.array([np.argmin(np.abs(sampled_time - temp_idx))])
            # Label every row from the first to the last match, inclusive.
            sampled_point_anno[mark_idx[0]: mark_idx[-1] + 1] = np.asarray(point_anno[act_idx])

        # Rows that carry a label for at least one class.
        act_set = set(np.where(sampled_point_anno.max(axis=1) == 1)[0].tolist())

        sample_idx = self._sample_indices(sampled_feature.shape[0])
        sample_set = set(sample_idx.tolist())

        # For each labelled row the sampler missed, drop one unlabelled pick.
        # Sorting makes the random choice independent of set iteration order.
        missed_act = act_set - sample_set
        droppable = sorted(sample_set - act_set)
        for _ in range(len(missed_act)):
            droppable.pop(random.randint(0, len(droppable) - 1))

        sample_idx_new = sorted(act_set.union(droppable))
        # The sampler may return duplicates that the sets collapsed; pad with
        # random repeats until the original number of rows is reached.
        if len(sample_idx_new) < len(sample_idx):
            supple_num = len(sample_idx) - len(sample_idx_new)
            for _ in range(supple_num):
                sample_idx_new.append(sample_idx_new[random.randint(0, len(sample_idx_new) - 1)])
        sample_idx_new.sort()

        sampled_feature = sampled_feature[sample_idx_new]
        sampled_point_anno = sampled_point_anno[sample_idx_new]

        return sampled_point_anno, sampled_feature

    def get_data(self, index):
        """Load the unsampled features of video ``index``.

        Returns:
            ``(feature_tensor, vid_num_seg)``.
        """
        vid_name = self.vid_list[index]
        feature = np.load(os.path.join(self.feature_path,
                                       vid_name + '.npy')).astype(np.float32)

        vid_num_seg = feature.shape[0]
        return torch.from_numpy(feature), vid_num_seg

    @classmethod
    def _segment_weights(cls, point_groups, vid_num_seg):
        """Build the per-segment resampling weights.

        Args:
            point_groups: Iterable of index sequences; the first and last
                element of each non-empty sequence delimit a labelled range.
            vid_num_seg: Number of segments.

        Returns:
            Float array of length ``vid_num_seg``: ``POINT_WEIGHT`` inside each
            range, ``NEIGHBOR_WEIGHT`` on the segment right before/after it,
            1 elsewhere. Later groups overwrite earlier ones.
        """
        weights = np.ones((vid_num_seg,), dtype=float)
        for group in point_groups:
            if len(group) > 0:
                weights[group[0]: group[-1] + 1] = cls.POINT_WEIGHT
                if group[0] - 1 >= 0:
                    weights[group[0] - 1] = cls.NEIGHBOR_WEIGHT
                if group[-1] + 1 <= vid_num_seg - 1:
                    weights[group[-1] + 1] = cls.NEIGHBOR_WEIGHT
        return weights

    def get_label(self, index, vid_num_seg):
        """Build the video-level label, point labels and resampling weights.

        Args:
            index: Position of the video in ``self.vid_list``.
            vid_num_seg: Number of segments of the video.

        Returns:
            ``(label, point_label_tensor, dynamic_segment_weights)`` for
            ``supervision == 'point'``; ``(label, empty_tensor)`` for
            ``supervision == 'video'``.
        """
        vid_name = self.vid_list[index]
        anno_list = self.anno['database'][vid_name]['annotations']
        label = np.zeros([self.num_classes], dtype=np.float32)

        for _anno in anno_list:
            label[self.class_name_to_idx[_anno['label']]] = 1

        if self.supervision == 'video':
            # NOTE(review): __getitem__ unpacks three values, so this
            # two-element return only suits direct callers of get_label --
            # confirm whether 'video' mode is meant to be usable end to end.
            return label, torch.Tensor(0)

        elif self.supervision == 'point':
            temp_anno = np.zeros([vid_num_seg, self.num_classes], dtype=np.float32)
            fps_vid = self.anno['database'][vid_name]['fps']
            t_factor = self.feature_fps / (fps_vid * 16)

            temp_df = self.point_anno[self.point_anno["video_id"] == vid_name][['point', 'class_index']]

            point_anno_lst = []
            for point, class_idx in zip(temp_df['point'], temp_df['class_index']):
                temp_idx = int(point * t_factor)
                # Points that map past the last segment go to the last one.
                if temp_idx >= vid_num_seg:
                    temp_idx = vid_num_seg - 1
                temp_anno[temp_idx][class_idx] = 1
                point_anno_lst.append(temp_idx)

            point_label = temp_anno

            # utils.grouping presumably merges neighbouring indices into
            # groups -- confirm against utils.
            dynamic_segment_weights = self._segment_weights(utils.grouping(point_anno_lst), vid_num_seg)

            return label, torch.from_numpy(point_label), dynamic_segment_weights

    def random_perturb(self, length):
        """Pick ``num_segments`` indices, one at random from each time bin.

        Returns ``arange(length)`` when ``num_segments`` is ``-1`` or equals
        ``length``. A bin includes the start of the next bin, so neighbouring
        picks can coincide.
        """
        if self.num_segments == length or self.num_segments == -1:
            return np.arange(length).astype(int)
        samples = np.arange(self.num_segments) * length / self.num_segments
        for i in range(self.num_segments):
            if i < self.num_segments - 1:
                if int(samples[i]) != int(samples[i + 1]):
                    samples[i] = np.random.choice(range(int(samples[i]), int(samples[i + 1]) + 1))
                else:
                    samples[i] = int(samples[i])
            else:
                # Last bin runs to the end of the video.
                if int(samples[i]) < length - 1:
                    samples[i] = np.random.choice(range(int(samples[i]), length))
                else:
                    samples[i] = int(samples[i])
        return samples.astype(int)

    def uniform_sampling(self, length):
        """Pick ``num_segments`` evenly spaced indices.

        Always returns ``num_segments`` indices (repeating some when
        ``length < num_segments``) unless ``num_segments`` is ``-1``, in which
        case ``arange(length)`` is returned.
        """
        if self.num_segments == -1:
            return np.arange(length).astype(int)
        samples = np.arange(self.num_segments) * length / self.num_segments
        samples = np.floor(samples)
        return samples.astype(int)


class GTEAFeature(data.Dataset):
    """Dataset of in-memory segment features with point-level labels.

    Files read below ``data_path``: ``classlist.npy``, ``subset.npy``,
    ``features/GTEA-I3D-JOINTFeatures.npy``, ``duration.npy``, ``labels.npy``,
    ``labels_all.npy``, ``segments.npy``, ``videoname.npy``, ``gtea_gt.json``
    and, when ``supervision == 'point'``,
    ``point_paussian/point_labels_gtea.csv`` (columns ``video_id``, ``point``,
    ``class``). The per-video arrays are restricted to the entries whose
    subset tag matches the requested mode.
    """

    # Resampling weight of a segment that carries a point label / of the
    # segment right before or after a labelled group. Every other segment has
    # weight 1.
    POINT_WEIGHT = 2.5
    NEIGHBOR_WEIGHT = 1.67

    def __init__(self, data_path, mode, modal, feature_fps, num_segments, sampling, seed=-1, supervision='point'):
        """Load all arrays, annotations and point labels for one split.

        Args:
            data_path: Root directory of the dataset files listed on the class.
            mode: A string containing ``'train'`` selects the ``'training'``
                subset, one containing ``'test'`` the ``'validation'`` subset;
                anything else selects no videos.
            modal: Stored on the instance; not used for loading here.
            feature_fps: Numerator of the point -> segment index conversion.
            num_segments: Number of segments to sample, ``-1`` to keep all.
            sampling: ``'random'`` or ``'uniform'``.
            seed: When ``>= 0`` it is passed to ``utils.set_seed``.
            supervision: ``'point'`` loads the point label CSV.
        """
        if seed >= 0:
            utils.set_seed(seed)

        self.mode = mode
        self.modal = modal
        self.feature_fps = feature_fps
        self.num_segments = num_segments
        self.sampling = sampling
        self.supervision = supervision

        self.classlist = np.load(os.path.join(data_path, 'classlist.npy'), allow_pickle=True)
        self.subset = np.load(os.path.join(data_path, 'subset.npy'), allow_pickle=True)
        self.subset = [bytes.decode(v) for v in self.subset]
        # Positions of the videos that belong to the requested split.
        self.d_idx = []
        if 'train' in self.mode:
            self.d_idx = get_index(self.subset, 'training')
        elif 'test' in self.mode:
            self.d_idx = get_index(self.subset, 'validation')

        self.feature = np.load(os.path.join(data_path, 'features', 'GTEA-I3D-JOINTFeatures.npy'), allow_pickle=True)[self.d_idx]
        self.duration = np.load(os.path.join(data_path, 'duration.npy'), allow_pickle=True)[self.d_idx]
        self.labels = np.load(os.path.join(data_path, 'labels.npy'), allow_pickle=True)[self.d_idx]
        self.labels_all = np.load(os.path.join(data_path, 'labels_all.npy'), allow_pickle=True)[self.d_idx]
        self.segments = np.load(os.path.join(data_path, 'segments.npy'), allow_pickle=True)[self.d_idx]
        self.videoname = np.load(os.path.join(data_path, 'videoname.npy'), allow_pickle=True)[self.d_idx]
        self.videoname = [bytes.decode(v) for v in self.videoname]

        # Context manager guarantees the handle is closed even on errors.
        anno_path = os.path.join(data_path, 'gtea_gt.json')
        with open(anno_path, 'r') as anno_file:
            self.anno = json.load(anno_file)

        # config.class_dict_gtea is inverted here, so its values become keys.
        self.class_name_to_idx = dict((v, k) for k, v in config.class_dict_gtea.items())
        self.num_classes = len(self.class_name_to_idx.keys())

        if self.supervision == 'point':
            # NOTE(review): ThumosFeature/ActivityFeature read from
            # 'point_gaussian'; confirm 'point_paussian' is the real
            # directory name on disk.
            self.point_anno = pd.read_csv(os.path.join(data_path, 'point_paussian', 'point_labels_gtea.csv'))

        # Per-video slots initialised to -1; nothing in this class writes them.
        self.stored_info_all = {'new_dense_anno': [-1] * len(self.d_idx),
                                'sequence_score': [-1] * len(self.d_idx)}

    def __len__(self):
        """Return the number of videos in the selected split."""
        return len(self.d_idx)

    def __getitem__(self, index):
        """Return one resampled video.

        Returns:
            Tuple ``(index, sampled_feature, label, sampled_point_anno,
            stored_info, video_name, num_rows)`` where ``num_rows`` is
            ``sampled_feature.shape[0]``.
        """
        data, vid_num_seg, sample_idx = self.get_data(index)
        label, point_anno, dynamic_segment_weights = self.get_label(index, vid_num_seg, sample_idx)
        sampled_point_anno, sampled_feature = self.dynamic_sample(data, point_anno, dynamic_segment_weights)

        stored_info = {'new_dense_anno': self.stored_info_all['new_dense_anno'][index],
                       'sequence_score': self.stored_info_all['sequence_score'][index]}

        return index, sampled_feature, label, sampled_point_anno, stored_info, self.videoname[index], \
        sampled_feature.shape[0]

    def dynamic_sample(self, feature, point_anno, dynamic_segment_weights):
        """Resample features and point labels according to per-segment weights.

        The output has ``round(sum(dynamic_segment_weights))`` rows; a segment
        with a larger weight is covered by proportionally more rows. Features
        are linearly interpolated between neighbouring segments.

        Args:
            feature: Per-segment features, first axis is time.
            point_anno: Tensor ``[segments, classes]`` with 1 at labelled points.
            dynamic_segment_weights: One weight per row of ``feature``.

        Returns:
            ``(sampled_point_anno, sampled_feature)`` as numpy arrays.
        """
        vid_num_seg = feature.shape[0]
        # Cumulative weight axis: segment i spans [cumsum[i], cumsum[i + 1]].
        weights_cumsum = np.concatenate(
            (np.zeros((1,), dtype=float), np.cumsum(dynamic_segment_weights)), axis=0)
        num_samples = np.round(weights_cumsum[-1]).astype(int)
        f_upsample = interpolate.interp1d(weights_cumsum, np.arange(vid_num_seg + 1), kind='linear',
                                          axis=0, fill_value='extrapolate')
        scale_x = np.linspace(1, num_samples, num_samples)
        # Fractional, 1-based segment position of every output row.
        sampled_time = f_upsample(scale_x)
        f_feature = interpolate.interp1d(np.arange(1, vid_num_seg + 1), feature, kind='linear', axis=0,
                                         fill_value='extrapolate')
        sampled_feature = f_feature(sampled_time)

        sampled_point_anno = np.zeros([sampled_feature.shape[0], self.num_classes], dtype=np.float32)
        point_anno_agnostic = point_anno.max(dim=1)[0]
        act_idx_lst = np.where(point_anno_agnostic == 1)[0]
        for act_idx in act_idx_lst:
            temp_idx = act_idx + 1  # 1-based, like sampled_time
            mark_idx = np.where((sampled_time >= (temp_idx - 0.51)) & (sampled_time < temp_idx + 0.51))[0]
            if mark_idx.size == 0:
                # No row fell into the window: keep the label on the closest
                # row instead of failing on an empty index array.
                mark_idx = np.array([np.argmin(np.abs(sampled_time - temp_idx))])
            # Label every row from the first to the last match, inclusive.
            sampled_point_anno[mark_idx[0]: mark_idx[-1] + 1] = np.asarray(point_anno[act_idx])

        return sampled_point_anno, sampled_feature

    def _sample_indices(self, vid_num_seg):
        """Dispatch to the sampler selected by ``self.sampling``."""
        if self.sampling == 'random':
            return self.random_perturb(vid_num_seg)
        elif self.sampling == 'uniform':
            return self.uniform_sampling(vid_num_seg)
        raise AssertionError('Not supported sampling !')

    def get_data(self, index):
        """Sample the in-memory features of video ``index``.

        Returns:
            ``(feature_tensor, vid_num_seg, sample_idx)`` where ``vid_num_seg``
            is the segment count before sampling and ``sample_idx`` the
            indices that were kept.
        """
        feature = self.feature[index]
        vid_num_seg = feature.shape[0]
        sample_idx = self._sample_indices(vid_num_seg)

        feature = feature[sample_idx]

        return torch.from_numpy(feature), vid_num_seg, sample_idx

    @classmethod
    def _segment_weights(cls, point_groups, vid_num_seg):
        """Build the per-segment resampling weights.

        Args:
            point_groups: Iterable of index sequences; the first and last
                element of each non-empty sequence delimit a labelled range.
            vid_num_seg: Number of segments.

        Returns:
            Float array of length ``vid_num_seg``: ``POINT_WEIGHT`` inside each
            range, ``NEIGHBOR_WEIGHT`` on the segment right before/after it,
            1 elsewhere. Later groups overwrite earlier ones.
        """
        weights = np.ones((vid_num_seg,), dtype=float)
        for group in point_groups:
            if len(group) > 0:
                weights[group[0]: group[-1] + 1] = cls.POINT_WEIGHT
                if group[0] - 1 >= 0:
                    weights[group[0] - 1] = cls.NEIGHBOR_WEIGHT
                if group[-1] + 1 <= vid_num_seg - 1:
                    weights[group[-1] + 1] = cls.NEIGHBOR_WEIGHT
        return weights

    def get_label(self, index, vid_num_seg, sample_idx):
        """Build the video-level label, point labels and resampling weights.

        Args:
            index: Position of the video within the selected split.
            vid_num_seg: Segment count before sampling.
            sample_idx: Indices returned by ``get_data``.

        Returns:
            ``(label, point_label_tensor, dynamic_segment_weights)`` for
            ``supervision == 'point'``; point labels and weights are both
            indexed by ``sample_idx`` so they line up with the sampled
            features. ``(label, empty_tensor)`` for ``supervision == 'video'``.
        """
        vid_name = self.videoname[index]
        anno_list = self.anno['database'][vid_name]['annotations']
        fps = self.anno['database'][vid_name]['fps']
        label = np.zeros([self.num_classes], dtype=np.float32)

        for _anno in anno_list:
            label[self.class_name_to_idx[_anno['label']]] = 1

        if self.supervision == 'video':
            # NOTE(review): __getitem__ unpacks three values, so this
            # two-element return only suits direct callers of get_label --
            # confirm whether 'video' mode is meant to be usable end to end.
            return label, torch.Tensor(0)

        elif self.supervision == 'point':
            temp_anno = np.zeros([vid_num_seg, self.num_classes], dtype=np.float32)
            t_factor = self.feature_fps / (fps * 16)

            temp_df = self.point_anno[self.point_anno["video_id"] == vid_name][['point', 'class']]

            point_anno_lst = []
            for point, cls_name in zip(temp_df['point'], temp_df['class']):
                class_idx = self.class_name_to_idx[cls_name]
                seg_idx = int(point * t_factor)
                # Clamp points that map past the last segment (same handling
                # as ActivityFeature.get_label) instead of raising IndexError.
                if seg_idx >= vid_num_seg:
                    seg_idx = vid_num_seg - 1
                temp_anno[seg_idx][class_idx] = 1
                point_anno_lst.append(seg_idx)

            point_label = temp_anno[sample_idx, :]

            # utils.grouping presumably merges neighbouring indices into
            # groups -- confirm against utils.
            dynamic_segment_weights = self._segment_weights(utils.grouping(point_anno_lst), vid_num_seg)
            # dynamic_sample pairs one weight with each *sampled* feature row;
            # without this indexing the lengths differ whenever sampling
            # changes the segment count and interp1d rejects the input.
            dynamic_segment_weights = dynamic_segment_weights[sample_idx]

            return label, torch.from_numpy(point_label), dynamic_segment_weights

    def random_perturb(self, length):
        """Pick ``num_segments`` indices, one at random from each time bin.

        Returns ``arange(length)`` when ``num_segments`` is ``-1`` or equals
        ``length``. A bin includes the start of the next bin, so neighbouring
        picks can coincide.
        """
        if self.num_segments == length or self.num_segments == -1:
            return np.arange(length).astype(int)
        samples = np.arange(self.num_segments) * length / self.num_segments
        for i in range(self.num_segments):
            if i < self.num_segments - 1:
                if int(samples[i]) != int(samples[i + 1]):
                    samples[i] = np.random.choice(range(int(samples[i]), int(samples[i + 1]) + 1))
                else:
                    samples[i] = int(samples[i])
            else:
                # Last bin runs to the end of the video.
                if int(samples[i]) < length - 1:
                    samples[i] = np.random.choice(range(int(samples[i]), length))
                else:
                    samples[i] = int(samples[i])
        return samples.astype(int)

    def uniform_sampling(self, length):
        """Pick ``num_segments`` evenly spaced indices.

        Returns ``arange(length)`` when the video has at most ``num_segments``
        segments or ``num_segments`` is ``-1``.
        """
        if length <= self.num_segments or self.num_segments == -1:
            return np.arange(length).astype(int)
        samples = np.arange(self.num_segments) * length / self.num_segments
        samples = np.floor(samples)
        return samples.astype(int)


class BEOIDFeature(data.Dataset):
    """Dataset of in-memory segment features with point-level labels.

    Files read below ``data_path``: ``classlist.npy``, ``subset.npy``,
    ``features/BEOID-I3D-JOINTFeatures.npy``, ``duration.npy``, ``labels.npy``,
    ``labels_all.npy``, ``segments.npy``, ``videoname.npy``, ``beoid_gt.json``
    and, when ``supervision == 'point'``,
    ``point_paussian/point_labels_beoid.csv`` (columns ``video_id``,
    ``point``, ``class``). The per-video arrays are restricted to the entries
    whose subset tag matches the requested mode.
    """

    # Resampling weight of a segment that carries a point label / of the
    # segment right before or after a labelled group. Every other segment has
    # weight 1.
    POINT_WEIGHT = 2.5
    NEIGHBOR_WEIGHT = 1.67

    def __init__(self, data_path, mode, modal, feature_fps, num_segments, sampling, seed=-1, supervision='point'):
        """Load all arrays, annotations and point labels for one split.

        Args:
            data_path: Root directory of the dataset files listed on the class.
            mode: A string containing ``'train'`` selects the ``'training'``
                subset, one containing ``'test'`` the ``'validation'`` subset;
                anything else selects no videos.
            modal: Stored on the instance; not used for loading here.
            feature_fps: Numerator of the point -> segment index conversion.
            num_segments: Number of segments to sample, ``-1`` to keep all.
            sampling: ``'random'`` or ``'uniform'``.
            seed: When ``>= 0`` it is passed to ``utils.set_seed``.
            supervision: ``'point'`` loads the point label CSV.
        """
        if seed >= 0:
            utils.set_seed(seed)

        self.mode = mode
        self.modal = modal
        self.feature_fps = feature_fps
        self.num_segments = num_segments
        self.sampling = sampling
        self.supervision = supervision

        self.classlist = np.load(os.path.join(data_path, 'classlist.npy'), allow_pickle=True)
        self.subset = np.load(os.path.join(data_path, 'subset.npy'), allow_pickle=True)
        self.subset = [bytes.decode(v) for v in self.subset]
        # Positions of the videos that belong to the requested split.
        self.d_idx = []
        if 'train' in self.mode:
            self.d_idx = get_index(self.subset, 'training')
        elif 'test' in self.mode:
            self.d_idx = get_index(self.subset, 'validation')

        self.feature = np.load(os.path.join(data_path, 'features', 'BEOID-I3D-JOINTFeatures.npy'), allow_pickle=True)[self.d_idx]
        self.duration = np.load(os.path.join(data_path, 'duration.npy'), allow_pickle=True)[self.d_idx]
        self.labels = np.load(os.path.join(data_path, 'labels.npy'), allow_pickle=True)[self.d_idx]
        self.labels_all = np.load(os.path.join(data_path, 'labels_all.npy'), allow_pickle=True)[self.d_idx]
        self.segments = np.load(os.path.join(data_path, 'segments.npy'), allow_pickle=True)[self.d_idx]
        self.videoname = np.load(os.path.join(data_path, 'videoname.npy'), allow_pickle=True)[self.d_idx]
        self.videoname = [bytes.decode(v) for v in self.videoname]

        # Context manager guarantees the handle is closed even on errors.
        anno_path = os.path.join(data_path, 'beoid_gt.json')
        with open(anno_path, 'r') as anno_file:
            self.anno = json.load(anno_file)

        # config.class_dict_beoid is inverted here, so its values become keys.
        self.class_name_to_idx = dict((v, k) for k, v in config.class_dict_beoid.items())
        self.num_classes = len(self.class_name_to_idx.keys())

        if self.supervision == 'point':
            # NOTE(review): ThumosFeature/ActivityFeature read from
            # 'point_gaussian'; confirm 'point_paussian' is the real
            # directory name on disk.
            self.point_anno = pd.read_csv(os.path.join(data_path, 'point_paussian', 'point_labels_beoid.csv'))

        # Per-video slots initialised to -1; nothing in this class writes them.
        self.stored_info_all = {'new_dense_anno': [-1] * len(self.d_idx),
                                'sequence_score': [-1] * len(self.d_idx)}

    def __len__(self):
        """Return the number of videos in the selected split."""
        return len(self.d_idx)

    def __getitem__(self, index):
        """Return one resampled video.

        Returns:
            Tuple ``(index, sampled_feature, label, sampled_point_anno,
            stored_info, video_name, num_rows)`` where ``num_rows`` is
            ``sampled_feature.shape[0]``.
        """
        data, vid_num_seg, sample_idx = self.get_data(index)
        label, point_anno, dynamic_segment_weights = self.get_label(index, vid_num_seg, sample_idx)
        sampled_point_anno, sampled_feature = self.dynamic_sample(data, point_anno, dynamic_segment_weights)

        stored_info = {'new_dense_anno': self.stored_info_all['new_dense_anno'][index],
                       'sequence_score': self.stored_info_all['sequence_score'][index]}

        return index, sampled_feature, label, sampled_point_anno, stored_info, self.videoname[index], \
        sampled_feature.shape[0]

    def dynamic_sample(self, feature, point_anno, dynamic_segment_weights):
        """Resample features and point labels according to per-segment weights.

        The output has ``round(sum(dynamic_segment_weights))`` rows; a segment
        with a larger weight is covered by proportionally more rows. Features
        are linearly interpolated between neighbouring segments.

        Args:
            feature: Per-segment features, first axis is time.
            point_anno: Tensor ``[segments, classes]`` with 1 at labelled points.
            dynamic_segment_weights: One weight per row of ``feature``.

        Returns:
            ``(sampled_point_anno, sampled_feature)`` as numpy arrays.
        """
        vid_num_seg = feature.shape[0]
        # Cumulative weight axis: segment i spans [cumsum[i], cumsum[i + 1]].
        weights_cumsum = np.concatenate(
            (np.zeros((1,), dtype=float), np.cumsum(dynamic_segment_weights)), axis=0)
        num_samples = np.round(weights_cumsum[-1]).astype(int)
        f_upsample = interpolate.interp1d(weights_cumsum, np.arange(vid_num_seg + 1), kind='linear',
                                          axis=0, fill_value='extrapolate')
        scale_x = np.linspace(1, num_samples, num_samples)
        # Fractional, 1-based segment position of every output row.
        sampled_time = f_upsample(scale_x)
        f_feature = interpolate.interp1d(np.arange(1, vid_num_seg + 1), feature, kind='linear', axis=0,
                                         fill_value='extrapolate')
        sampled_feature = f_feature(sampled_time)

        sampled_point_anno = np.zeros([sampled_feature.shape[0], self.num_classes], dtype=np.float32)
        point_anno_agnostic = point_anno.max(dim=1)[0]
        act_idx_lst = np.where(point_anno_agnostic == 1)[0]
        for act_idx in act_idx_lst:
            temp_idx = act_idx + 1  # 1-based, like sampled_time
            mark_idx = np.where((sampled_time >= (temp_idx - 0.51)) & (sampled_time < temp_idx + 0.51))[0]
            if mark_idx.size == 0:
                # No row fell into the window: keep the label on the closest
                # row instead of failing on an empty index array.
                mark_idx = np.array([np.argmin(np.abs(sampled_time - temp_idx))])
            # Label every row from the first to the last match, inclusive.
            sampled_point_anno[mark_idx[0]: mark_idx[-1] + 1] = np.asarray(point_anno[act_idx])

        return sampled_point_anno, sampled_feature

    def _sample_indices(self, vid_num_seg):
        """Dispatch to the sampler selected by ``self.sampling``."""
        if self.sampling == 'random':
            return self.random_perturb(vid_num_seg)
        elif self.sampling == 'uniform':
            return self.uniform_sampling(vid_num_seg)
        raise AssertionError('Not supported sampling !')

    def get_data(self, index):
        """Sample the in-memory features of video ``index``.

        Returns:
            ``(feature_tensor, vid_num_seg, sample_idx)`` where ``vid_num_seg``
            is the segment count before sampling and ``sample_idx`` the
            indices that were kept.
        """
        feature = self.feature[index]
        vid_num_seg = feature.shape[0]
        sample_idx = self._sample_indices(vid_num_seg)

        feature = feature[sample_idx]

        return torch.from_numpy(feature), vid_num_seg, sample_idx

    @classmethod
    def _segment_weights(cls, point_groups, vid_num_seg):
        """Build the per-segment resampling weights.

        Args:
            point_groups: Iterable of index sequences; the first and last
                element of each non-empty sequence delimit a labelled range.
            vid_num_seg: Number of segments.

        Returns:
            Float array of length ``vid_num_seg``: ``POINT_WEIGHT`` inside each
            range, ``NEIGHBOR_WEIGHT`` on the segment right before/after it,
            1 elsewhere. Later groups overwrite earlier ones.
        """
        weights = np.ones((vid_num_seg,), dtype=float)
        for group in point_groups:
            if len(group) > 0:
                weights[group[0]: group[-1] + 1] = cls.POINT_WEIGHT
                if group[0] - 1 >= 0:
                    weights[group[0] - 1] = cls.NEIGHBOR_WEIGHT
                if group[-1] + 1 <= vid_num_seg - 1:
                    weights[group[-1] + 1] = cls.NEIGHBOR_WEIGHT
        return weights

    def get_label(self, index, vid_num_seg, sample_idx):
        """Build the video-level label, point labels and resampling weights.

        Args:
            index: Position of the video within the selected split.
            vid_num_seg: Segment count before sampling.
            sample_idx: Indices returned by ``get_data``.

        Returns:
            ``(label, point_label_tensor, dynamic_segment_weights)`` for
            ``supervision == 'point'``; point labels and weights are both
            indexed by ``sample_idx`` so they line up with the sampled
            features. ``(label, empty_tensor)`` for ``supervision == 'video'``.
        """
        vid_name = self.videoname[index]
        anno_list = self.anno['database'][vid_name]['annotations']
        fps = self.anno['database'][vid_name]['fps']
        label = np.zeros([self.num_classes], dtype=np.float32)

        for _anno in anno_list:
            label[self.class_name_to_idx[_anno['label']]] = 1

        if self.supervision == 'video':
            # NOTE(review): __getitem__ unpacks three values, so this
            # two-element return only suits direct callers of get_label --
            # confirm whether 'video' mode is meant to be usable end to end.
            return label, torch.Tensor(0)

        elif self.supervision == 'point':
            temp_anno = np.zeros([vid_num_seg, self.num_classes], dtype=np.float32)
            t_factor = self.feature_fps / (fps * 16)

            temp_df = self.point_anno[self.point_anno["video_id"] == vid_name][['point', 'class']]

            point_anno_lst = []
            for point, cls_name in zip(temp_df['point'], temp_df['class']):
                class_idx = self.class_name_to_idx[cls_name]
                seg_idx = int(point * t_factor)
                # Clamp points that map past the last segment (same handling
                # as ActivityFeature.get_label) instead of raising IndexError.
                if seg_idx >= vid_num_seg:
                    seg_idx = vid_num_seg - 1
                temp_anno[seg_idx][class_idx] = 1
                point_anno_lst.append(seg_idx)

            point_label = temp_anno[sample_idx, :]

            # utils.grouping presumably merges neighbouring indices into
            # groups -- confirm against utils.
            dynamic_segment_weights = self._segment_weights(utils.grouping(point_anno_lst), vid_num_seg)
            # dynamic_sample pairs one weight with each *sampled* feature row;
            # without this indexing the lengths differ whenever sampling
            # changes the segment count and interp1d rejects the input.
            dynamic_segment_weights = dynamic_segment_weights[sample_idx]

            return label, torch.from_numpy(point_label), dynamic_segment_weights

    def random_perturb(self, length):
        """Pick ``num_segments`` indices, one at random from each time bin.

        Returns ``arange(length)`` when ``num_segments`` is ``-1`` or equals
        ``length``. A bin includes the start of the next bin, so neighbouring
        picks can coincide.
        """
        if self.num_segments == length or self.num_segments == -1:
            return np.arange(length).astype(int)
        samples = np.arange(self.num_segments) * length / self.num_segments
        for i in range(self.num_segments):
            if i < self.num_segments - 1:
                if int(samples[i]) != int(samples[i + 1]):
                    samples[i] = np.random.choice(range(int(samples[i]), int(samples[i + 1]) + 1))
                else:
                    samples[i] = int(samples[i])
            else:
                # Last bin runs to the end of the video.
                if int(samples[i]) < length - 1:
                    samples[i] = np.random.choice(range(int(samples[i]), length))
                else:
                    samples[i] = int(samples[i])
        return samples.astype(int)

    def uniform_sampling(self, length):
        """Pick ``num_segments`` evenly spaced indices.

        Returns ``arange(length)`` when the video has at most ``num_segments``
        segments or ``num_segments`` is ``-1``.
        """
        if length <= self.num_segments or self.num_segments == -1:
            return np.arange(length).astype(int)
        samples = np.arange(self.num_segments) * length / self.num_segments
        samples = np.floor(samples)
        return samples.astype(int)


def build_dataset(dataset_name, data_path, mode, modal, feature_fps, num_segments, sampling, seed=-1, supervision='point'):
    """Create the dataset object that matches ``dataset_name``.

    The name is matched by substring, in this order: ``'thumos14'``,
    ``'activitynet13'``, ``'GTEA'``, ``'BEOID'``. All remaining arguments are
    forwarded unchanged to the selected class.

    Raises:
        ValueError: If ``dataset_name`` contains none of the known tags
            (previously ``None`` was returned silently, which only failed
            later at the first use of the dataset).
    """
    if 'thumos14' in dataset_name:
        return ThumosFeature(data_path, mode, modal, feature_fps, num_segments, sampling, seed, supervision)
    elif 'activitynet13' in dataset_name:
        return ActivityFeature(data_path, mode, modal, feature_fps, num_segments, sampling, seed, supervision)
    elif 'GTEA' in dataset_name:
        return GTEAFeature(data_path, mode, modal, feature_fps, num_segments, sampling, seed, supervision)
    elif 'BEOID' in dataset_name:
        return BEOIDFeature(data_path, mode, modal, feature_fps, num_segments, sampling, seed, supervision)
    raise ValueError('Unsupported dataset name: {!r}'.format(dataset_name))


def get_index(lst=None, item=''):
    """Return the positions of every element of ``lst`` equal to ``item``.

    Args:
        lst: Iterable to scan; ``None`` (the default) is treated as empty
            instead of raising ``TypeError``.
        item: Value to compare each element against with ``==``.

    Returns:
        List of 0-based positions in ascending order.
    """
    if lst is None:
        return []
    return [pos for pos, value in enumerate(lst) if value == item]
